Interpretability module: sparse linear models via LASSO #120

Merged: luigibonati merged 26 commits into main from interpretability on Jun 11, 2024

luigibonati
Owner

@luigibonati luigibonati commented Jan 30, 2024

Description

Add sparse linear models optimized via LASSO as tools for interpreting the CVs and/or the resulting states, as done here: https://pubs.acs.org/doi/abs/10.1021/acs.jctc.2c00393.

I started from the notebook that @pietronvll and I wrote. We implemented both the classification case (as done in stateinterpreter) and the regression one. A few changes:

  • I extended the functions to also work in the multi-class case.
  • I changed the scoring function to balanced_accuracy_score instead of the standard accuracy, so that imbalanced datasets are handled properly.

The signature is (almost) the same for regression and classification: both return the optimized estimator together with the list of non-zero features and their coefficients. I also added separate functions to plot the results (coefficient paths, score and number of features).
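
As a rough illustration of the classification case described above, here is a minimal sketch built directly on scikit-learn rather than on the functions added in this PR. The helper name `sparse_classifier`, the synthetic data and the feature names are assumptions made for the example. It covers only the binary case and uses the `liblinear` solver, which may not match the choices made in the PR. Keyword names follow scikit-learn 1.x and may change in later releases.

```python
import numpy as np
from sklearn.linear_model import LogisticRegressionCV
from sklearn.preprocessing import StandardScaler


def sparse_classifier(X, labels, feature_names, Cs=10, cv=5, seed=0):
    """Hypothetical helper (not the mlcolvar API): L1-penalized logistic
    regression, with the regularization strength chosen by cross-validated
    balanced accuracy. Binary labels only."""
    # Standardize so that coefficient magnitudes are comparable across features
    X_scaled = StandardScaler().fit_transform(X)
    clf = LogisticRegressionCV(
        Cs=Cs,
        cv=cv,
        penalty="l1",
        solver="liblinear",
        scoring="balanced_accuracy",
        random_state=seed,
    )
    clf.fit(X_scaled, labels)
    # In the binary case coef_ has shape (1, n_features)
    coefs = clf.coef_[0]
    selected = {n: float(c) for n, c in zip(feature_names, coefs) if c != 0}
    return clf, selected


# Synthetic, imbalanced two-state dataset: only feature f0 defines the state
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 6))
labels = (X[:, 0] > 1.0).astype(int)
names = [f"f{i}" for i in range(X.shape[1])]

clf, selected = sparse_classifier(X, labels, names)
print(selected)
```

Balanced accuracy is the mean of the per-class recalls, so in this binary setting a model that always predicts the majority state scores 0.5 instead of the majority fraction, which is why it is a safer selection criterion on imbalanced data.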

Todos

Notable points that this PR has either accomplished or will accomplish.

  • Function: lasso_classification (based on scikit-learn's LogisticRegressionCV)
  • Function: lasso_regression (based on scikit-learn's LassoCV)
  • Plotting functions
  • Docstrings
  • Regtests
  • Raise error when importing module if scikit-learn is not installed
  • Add documentation pages
  • Add scikit-learn dependency to GA
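
For the regression item in the list above, a minimal sketch of how a sparse linear model can be fit with scikit-learn's `LassoCV` is given here. It does not reproduce `lasso_regression` from this PR: the synthetic data, the variable names and the choice of 5-fold cross-validation are assumptions made for the example.

```python
import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.preprocessing import StandardScaler

# Synthetic data: the target (standing in for a CV) depends on f1 and f4 only
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 8))
target = 2.0 * X[:, 1] - 1.0 * X[:, 4] + 0.05 * rng.normal(size=300)
names = [f"f{i}" for i in range(X.shape[1])]

# Standardize the inputs so that coefficients can be compared
X_scaled = StandardScaler().fit_transform(X)

# LassoCV scans a path of regularization strengths and keeps the one
# with the lowest cross-validated mean squared error
reg = LassoCV(cv=5, random_state=0).fit(X_scaled, target)

# Keep only the features with a non-zero coefficient
selected = {n: float(c) for n, c in zip(names, reg.coef_) if c != 0}
print(f"alpha = {reg.alpha_:.4g}, {len(selected)} non-zero features")
print(selected)
```

The dictionary of non-zero coefficients is the kind of output the PR description refers to as "the list of non-zero features and their coefficients".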

Tutorials

Work in progress

Questions

  • This requires scikit-learn as an additional dependency, which I would keep optional.
  • For now, I put these functions inside utils.lasso. However, since the sensitivity analysis already lives in utils.explain, should we move all these functions into a new module called explain?

Status

  • Ready to go


codecov bot commented Jan 30, 2024

Codecov Report

Attention: Patch coverage is 91.90751% with 28 lines in your changes missing coverage. Please review.

Project coverage is 92.50%. Comparing base (3f9adeb) to head (71ad599).


import matplotlib
import matplotlib.pyplot as plt
import mlcolvar.utils.plot

CodeQL notice (Unused import): Import of 'mlcolvar' is not used.
import mlcolvar.utils.plot

try:
    import sklearn

CodeQL notice (Unused import): Import of 'sklearn' is not used.
@@ -1,8 +1,10 @@
import numpy as np
import torch
from matplotlib import patches as mpatches
import matplotlib.pyplot as plt
import mlcolvar.utils.plot

CodeQL notice (Unused import): Import of 'mlcolvar' is not used.
@@ -0,0 +1,7 @@
import pytest

CodeQL notice (Unused import, test): Import of 'pytest' is not used.
try:
    import sklearn
except ImportError:
    print('The lasso module requires scikit-learn as additional dependency.')

CodeQL notice (Use of a print statement at module level): Print statement may execute during import.
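
One way to address this notice, and the "raise error when importing module" item from the todo list, is to raise an `ImportError` with an actionable message instead of printing. The sketch below is an assumption about how this could look, not the code that was merged: the helper `require_optional` and its arguments are hypothetical, and it relies only on the standard library's `importlib.util.find_spec`.

```python
import importlib.util


def require_optional(module_name, package_name, feature):
    """Raise a descriptive ImportError if an optional dependency is missing.

    Hypothetical helper: module_name is the import name (e.g. 'sklearn'),
    package_name is the name used for installation (e.g. 'scikit-learn').
    """
    # find_spec returns None when a top-level module cannot be located
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"The {feature} requires {package_name} as an additional "
            f"dependency. Install it with `pip install {package_name}`."
        )


# A module from the standard library is always found, so nothing is raised
require_optional("json", "json", "demo feature")

# A missing module produces an explicit error instead of a silent print
try:
    require_optional("surely_not_installed_module", "some-package", "lasso module")
except ImportError as err:
    message = str(err)
    print(message)
```

Called at the top of the optional module, for example `require_optional("sklearn", "scikit-learn", "lasso module")`, this fails at import time with a clear message while keeping scikit-learn out of the mandatory dependencies.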
@luigibonati luigibonati removed the request for review from pietronvll June 6, 2024 11:50
fig, axs = plt.subplots(n_feat, 1, figsize=(3, 3*n_feat))

plt.suptitle('Features distribution')
init_ax = True

CodeQL notice (Unused local variable): Variable init_ax is not used.
ax.set_xlim(0, None)
if n_feat != len(axs):
    raise ValueError(f'Number of features ({len(features)}) != number of axis ({len(axs)})')
init_ax = False

CodeQL notice (Unused local variable): Variable init_ax is not used.
@luigibonati
Owner Author

I have put everything into a new explain submodule, which contains the sensitivity analysis and the sparse models.

I will merge it soon.

@luigibonati luigibonati merged commit d4fb5d7 into main Jun 11, 2024
12 checks passed
@luigibonati luigibonati deleted the interpretability branch June 11, 2024 15:41